package drds.plus.executor.function.aggregate_function;

import drds.plus.executor.ExecuteContext;
import drds.plus.sql_process.abstract_syntax_tree.expression.item.function.Function;
import drds.plus.sql_process.type.Type;

/**
 * Aggregate implementation of SQL {@code AVG}.
 *
 * <p>AVG cannot be merged directly across shards, so it is rewritten into {@code SUM} + {@code COUNT}
 * (see {@link #getDataBaseFunction()}). Once the partial results from every shard have been collected,
 * the average is computed here as {@code total / count}.
 *
 * <p>As in SQL, {@code NULL} inputs are ignored. The result is {@code null} when no non-null value was
 * aggregated.
 *
 * <p>Instances hold mutable accumulation state without any synchronization.
 */
public class Avg extends AggregateFunction {

    /** Number of non-null values aggregated so far. */
    private long count = 0L;
    /** Running sum (typed by {@link #getSumType()}), or {@code null} if nothing has been added yet. */
    private Object total = null;

    /**
     * Accumulates one raw row value (map phase).
     *
     * <p>{@code NULL} values are skipped entirely: they contribute to neither the sum nor the count.
     * This matches SQL {@code AVG} and the {@code COUNT(...)} that is pushed down for the reduce path.
     *
     * @param executeContext execution context (not used here)
     * @param args           {@code args[0]} is the value being averaged
     */
    public void map(ExecuteContext executeContext, Object[] args) {
        Object value = args[0];
        if (value == null) {
            return;
        }
        count++;
        accumulate(value);
    }

    /**
     * Merges one partial (sum, count) pair produced by a shard (reduce phase).
     *
     * @param executeContext execution context (not used here)
     * @param args           {@code args[0]} is a partial sum and {@code args[1]} the matching partial
     *                       count; the pair is ignored if either is {@code null}
     */
    public void reduce(ExecuteContext executeContext, Object[] args) {
        Object partialSum = args[0];
        Object partialCount = args[1];
        if (partialSum == null || partialCount == null) {
            return;
        }
        count += Type.LongType.convert(partialCount);
        accumulate(partialSum);
    }

    /** Adds {@code value} to the running total, converting the first value to the sum type. */
    private void accumulate(Object value) {
        Type type = getSumType();
        if (total == null) {
            total = type.convert(value);
        } else {
            total = type.getCalculator().add(total, value);
        }
    }

    /**
     * Returns {@code total / count} computed with the calculator of {@link #getReturnType()}.
     *
     * @return the average, or {@code null} when no non-null value has been aggregated
     */
    public Object getResult() {
        if (total == null || count == 0L) {
            // SQL AVG over no non-null input yields NULL, not 0 and not a division by zero.
            return null;
        }
        return getReturnType().getCalculator().divide(total, count);
    }

    /** Resets the accumulated sum and count so the instance can be reused. */
    public void clear() {
        this.total = null;
        this.count = 0L;
    }

    /** Returns the final result type; identical to {@link #getMapReturnType()}. */
    public Type getReturnType() {
        return getMapReturnType();
    }

    /**
     * Returns the type of the computed average.
     *
     * <p>Always BigDecimal, including for BigInteger input. Double loses precision and would produce
     * results that differ from MySQL.
     */
    public Type getMapReturnType() {
        return Type.BigDecimalType;
    }

    /** Returns the type used to accumulate the sum: int/short input is widened to long. */
    public Type getSumType() {
        Type type = getFirstArgType();
        if (type == Type.IntegerType || type == Type.ShortType) {
            return Type.LongType;
        }
        return type;
    }

    /** Returns the "sum,count" column list sent to each shard in place of the AVG expression. */
    public String getDataBaseFunction() {
        return buildAvgSql(function);
    }

    /**
     * Builds the comma-separated SUM and COUNT expressions that replace this AVG.
     *
     * <p>If the function has an alias, the alias suffixed with "1" and "2" is used. These names must
     * stay consistent with FuckAvgOptimizer.
     */
    private String buildAvgSql(Function function) {
        String columnName = function.getColumnName();
        String alias = function.getAlias();
        StringBuilder sb = new StringBuilder();
        if (alias != null) {
            sb.append(alias).append('1').append(',').append(alias).append('2');
        } else {
            sb.append(renameFunction(columnName, "sum"));
            sb.append(',').append(renameFunction(columnName, "count"));
        }
        return sb.toString();
    }

    /**
     * Replaces only the leading {@code avg} function name (case-insensitive) in {@code columnName}, so
     * that occurrences of "avg" inside the argument (e.g. {@code avg(avg_price)}) are left untouched.
     * Falls back to the previous replace-all behavior if the name does not start with "avg".
     */
    private static String renameFunction(String columnName, String newName) {
        if (columnName.regionMatches(true, 0, "avg", 0, 3)) {
            return newName + columnName.substring(3);
        }
        return columnName.replace("avg", newName);
    }

    /** Returns the SQL function names handled by this implementation. */
    public String[] getFunctionNames() {
        return new String[]{"avg"};
    }
}
